#!/usr/bin/python

import numpy,pylab,random,sqlite3,collections,os,re
from pylab import *
from scipy import linalg
from scipy.ndimage.morphology import binary_erosion
from scipy.ndimage import interpolation,filters
from ocrolib import dbhelper,improc
import cv,random,pyflann
from collections import Counter,defaultdict
from optparse import OptionParser
import shelve
import tables
from tables import *
from ocrolib import docproc

class record:
    """Plain attribute container: every keyword argument becomes an attribute."""
    def __init__(self,**kw):
        # store the keyword arguments directly as instance attributes
        vars(self).update(kw)
    def __str__(self):
        # the string form is simply that of the attribute dictionary
        attributes = vars(self)
        return str(attributes)
    
def uencode(s):
    """Pack a string of at most four characters into a single integer.

    Each character occupies 16 bits; the first character of `s` ends up in
    the least significant 16 bits, the second in the next 16 bits, and so
    on.  The empty string encodes to 0.  `udecode` performs the reverse
    mapping.

    Raises AssertionError if `s` has more than four characters and
    ValueError if a character does not fit into 16 bits.
    """
    assert len(s)<=4,"uencode: cannot pack more than 4 characters (got %d)"%len(s)
    result = 0
    # Walk the string back to front so that the first character is shifted
    # in last and therefore lands in the low bits.
    for c in reversed(s):
        code = ord(c)
        if code>0xffff:
            # A wider code point would spill into the slot of the next
            # character and silently corrupt the encoding.
            raise ValueError("uencode: character %r does not fit into 16 bits"%(c,))
        result = (result<<16)|code
    return result

def udecode(i):
    """Unpack an integer produced by `uencode` back into a string.

    Characters are read 16 bits at a time starting with the least
    significant bits; decoding stops once the remaining value is zero, so
    0 decodes to the empty string.

    Raises ValueError for negative input.
    """
    if i<0:
        # Right-shifting a negative integer converges to -1 and never
        # reaches 0, so the loop below would not terminate.
        raise ValueError("udecode: cannot decode negative value %r"%(i,))
    try:
        tochar = unichr      # Python 2
    except NameError:
        tochar = chr         # Python 3: chr already covers all of Unicode
    result = []
    while i!=0:
        result.append(tochar(i&0xffff))
        i >>= 16
    return "".join(result)

def get_images(cname,table='chars'):
    """Yield one `record` for every row of `table` in the SQLite file `cname`.

    Each record carries the attributes id, image, cls, rel, cost, segid,
    count, cluster, bbox, classes and file.  `image` is the row's image
    blob decoded with dbhelper.blob2image, passed through
    improc.pad_by(...,1) and converted to an array of type 'B'.  `rel` is
    the whitespace-separated `rel` column parsed into a float array, or
    None when that column cannot be parsed.

    The database connection is closed when the generator finishes or is
    discarded.
    """
    # sqlite3's own context manager only commits/rolls back; it does not
    # close the connection, so close it explicitly.
    db = sqlite3.connect(cname)
    try:
        db.row_factory = dbhelper.DbRow
        # Table names cannot be bound as SQL parameters, hence the string
        # interpolation; `table` is expected to be a trusted identifier.
        query = "select * from %s"%table
        for row in db.execute(query):
            image = array(improc.pad_by(dbhelper.blob2image(row.image),1),'B')
            try:
                rel = array([float(x) for x in row.rel.split()])
            except Exception:
                # Geometry is optional: a missing or malformed value is
                # reported as None rather than aborting the scan.
                rel = None
            yield record(id=row.id,image=image,cls=row.cls,rel=rel,cost=row.cost,
                segid=row.segid,count=row.count,cluster=row.cluster,bbox=row.bbox,
                classes=row.classes,file=row.file)
    finally:
        db.close()

def csnormalize(image,f=0.75):
    """Center and isotropically rescale `image` based on its foreground moments.

    The foreground is every pixel brighter than the midpoint between the
    image's minimum and maximum.  Its centroid is moved to the pixel
    (w//2,h//2) and the image is scaled by f*max(image.shape)/(4*l), where
    l is the square root of the largest eigenvalue of the foreground's
    covariance matrix.  Resampling uses linear interpolation (order=1) and
    the result has the same shape as the input.

    The input is returned unchanged when it has no foreground or when the
    foreground has no spatial extent (for instance a single pixel), since
    no meaningful scale can be derived in those cases.
    """
    from scipy.ndimage import affine_transform
    # numpy is referenced explicitly so that the reductions below do not
    # depend on a star import shadowing the builtin sum().
    threshold = numpy.mean([numpy.amax(image),numpy.amin(image)])
    bimage = 1*(image>threshold)
    w,h = bimage.shape
    xs,ys = numpy.mgrid[0:w,0:h]
    s = numpy.sum(bimage)
    if s<1e-4: return image
    s = 1.0/s
    # first and second order moments of the foreground mask
    cx = numpy.sum(xs*bimage)*s
    cy = numpy.sum(ys*bimage)*s
    sxx = numpy.sum((xs-cx)**2*bimage)*s
    sxy = numpy.sum((xs-cx)*(ys-cy)*bimage)*s
    syy = numpy.sum((ys-cy)**2*bimage)*s
    evals,_ = numpy.linalg.eigh(numpy.array([[sxx,sxy],[sxy,syy]]))
    largest = numpy.amax(evals)
    # Zero variance (or NaN) would make the scale infinite/undefined.
    if not largest>0: return image
    l = numpy.sqrt(largest)
    scale = f*max(image.shape)/(4.0*l)
    # affine_transform maps output coordinates to input coordinates, so the
    # matrix holds the inverse scale.
    m = numpy.array([[1.0/scale,0],[0.0,1.0/scale]])
    w,h = image.shape
    c = numpy.array([cx,cy])
    # explicit floor division: the target center is an integer pixel index
    d = c-numpy.dot(m,numpy.array([w//2,h//2]))
    return affine_transform(image,m,offset=d,order=1)

def table_log(db,*args):
    """Store the space-joined `args` as attribute LOG_<unix time> on node "/" of `db`."""
    from time import time as now
    # the attribute name carries the current time truncated to whole seconds
    stamp = int(now())
    name = "LOG_%d"%stamp
    message = " ".join(args)
    db.setNodeAttr("/",name,message)

# Command line interface of the converter.
import argparse
parser = argparse.ArgumentParser( description = "Convert character databases in SQLite3 format to HDF5 format.")
# nargs='?' makes the positional optional; without it argparse always
# requires a value and the declared default could never take effect.
parser.add_argument('db',nargs='?',default='training.db',help="db file")
parser.add_argument('-o','--output',default=None,help="hdf5 ouput file")
parser.add_argument('-n','--nimages',type=int,default=2000000000,help="max # images to convert")
# the pattern is matched against each record's class string
parser.add_argument('-r','--pattern',default='.*',help="pattern for characters to transform")
parser.add_argument('-p','--size',type=int,default=32,help="patchsize; 0 stores pickled Python arrays instead")
parser.add_argument('-N','--nonormalize',action="store_true",help="do not perform size normalization")
parser.add_argument('-t','--table',default="chars",help="database table")
parser.add_argument('-g','--nogeometry',action='store_true',help='do not copy over geometry information')
parser.add_argument('-e','--extended',action='store_true',help='copy over extended information')
# every DISPLAY-th converted character is shown when the value is positive
parser.add_argument('-D','--display',type=int,default=0,help='if non-zero, display characters')
args = parser.parse_args()

# Main conversion: validate the options, create the HDF5 arrays and copy the
# selected records from the SQLite database into them.

# sys is used below; import it explicitly instead of relying on a star
# import to provide it.
import sys

# parser.error() prints the usage text and keeps working under "python -O",
# which would strip plain assert statements.
if args.output is None:
    parser.error("no output file given; use -o/--output")
size = args.size
# A size of 0 selects the "store pickled arrays" mode documented for -p, so
# it has to pass validation alongside the regular patch sizes.
if size!=0 and not (5<size<256):
    parser.error("patch size must be 0 or between 6 and 255")

# The pattern has to match the complete class string.  The non-capturing
# group keeps the "$" from binding to only the last alternative of a
# pattern such as "a|b"; compiling once avoids redoing it for every row.
class_re = re.compile("(?:%s)$"%args.pattern)

h5 = tables.openFile(args.output,"w")
try:
    # Output arrays are created inside the try block so that the file is
    # closed even if one of the creation calls fails.
    if size==0:
        images = h5.createVLArray(h5.root,'images',ObjectAtom(),filters=Filters(9))
    else:
        patches = h5.createEArray(h5.root,'patches',Float32Atom(),shape=(0,size,size),
                                  title="characters as patches from "+args.db,
                                  filters=Filters(9))

    classes = h5.createEArray(h5.root,'classes',Int64Atom(),shape=(0,),filters=tables.Filters(9))
    # record the command line in the output file
    table_log(h5,"%s"%sys.argv)

    if not args.nogeometry:
        rel = h5.createEArray(h5.root,'rel',Float32Atom(shape=(3,)),shape=(0,),filters=tables.Filters(9))

    if args.extended:
        cclasses = h5.createVLArray(h5.root,'cclasses',StringAtom(120),filters=Filters(9))
        files = h5.createVLArray(h5.root,'files',StringAtom(120),filters=Filters(9))
        bboxes = h5.createEArray(h5.root,'bboxes',Float32Atom(shape=(4,)),shape=(0,),filters=Filters(9))
        costs = h5.createEArray(h5.root,'costs',Float32Atom(),shape=(0,),filters=tables.Filters(9))
        segids = h5.createEArray(h5.root,'segids',Int16Atom(),shape=(0,),filters=tables.Filters(9))
        clusters = h5.createEArray(h5.root,'clusters',Int32Atom(),shape=(0,),filters=tables.Filters(9))

    print('selecting patterns with classes matching: /^'+args.pattern+'$/')
    if size==0:
        print("storing unnormalized images")
    elif args.nonormalize:
        print("normalizing by isotropic rescaling into %d x %d bounding box"%(size,size))
    else:
        print("normalizing by moments into %d x %d bounding box"%(size,size))

    for r in get_images(args.db,table=args.table):
        if len(classes)>=args.nimages: break
        if not class_re.match(r.cls): continue
        if len(classes)%1000==0:
            # progress counter, rewritten in place
            sys.stdout.write("%d\r"%len(classes)); sys.stdout.flush()
        if not args.nogeometry:
            # check the record's geometry, not the output array of the same name
            if r.rel is None:
                raise ValueError("geometry information missing from database; use -g")
            assert len(rel)==len(classes)
            # one row per append, wrapped in a list like the other arrays
            rel.append([r.rel])
        bbox = [int(x) for x in r.bbox.split()]
        if args.extended:
            cclasses.append([r.classes])
            files.append([r.file])
            bboxes.append([bbox])
            costs.append([r.cost])
            segids.append([int(r.segid or 0)])
            clusters.append([int(r.cluster or 0)])
        rshape = r.image.shape
        if size==0:
            # keep a reference so that the optional display below works in
            # this mode as well
            image = r.image
            images.append(image)
        else:
            image = array(r.image,'f')/255.0
            if not args.nonormalize:
                image = docproc.isotropic_rescale(image,size)
                image = csnormalize(image)
            else:
                # rescale into a box two pixels smaller, then pad by one
                image = docproc.isotropic_rescale(image,size-2)
                image = improc.pad_by(image,1)
            patches.append([image])
        classes.append([uencode(r.cls)])
        if args.display>0 and len(classes)%args.display==0:
            ion(); gray(); clf(); imshow(image); ginput(1,0.001)
            print("%s %s %s %s %s %s %s %s %s"%(r.id,r.cls,rshape,r.rel,r.cost,r.count,r.segid,r.classes,bbox))
finally:
    h5.close()

print("done")

